Add databricks_automl to the conda env for arima and deepar #151
Conversation
Can you add a unit test
@apeforest The existing unit test that logs, loads, and predicts with an ARIMA model passes; see here: https://github.com/databricks/automl/blob/branch-0.2.20.3/runtime/tests/automl_runtime/forecast/pmdarima/model_test.py#L432. I could add a test that checks the logged requirements.txt includes "automl_runtime=0.2.20.x" for all 3 models?
@@ -57,10 +57,14 @@ def mlflow_forecast_log_model(forecast_model: ForecastModel,
:param forecast_model: Forecast model wrapper
:param sample_input: sample input Dataframes for model inference
"""
# log the model without signature if infer_signature is failed.
# TODO: we should not be logging without a signature since it cannot be registered to UC then
Please add a JIRA ticket number to TODO:
e.g. # TODO(ML-XXXXX): we should not be logging...
@@ -25,8 +25,10 @@
from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE
from pmdarima.arima import ARIMA

from databricks.automl_runtime.forecast.pmdarima.model import ArimaModel, MultiSeriesArimaModel, AbstractArimaModel, \
    mlflow_arima_log_model
from databricks.automl_runtime.forecast.pmdarima.model import (
please break imports into multiple lines:
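For illustration only, the style being asked for looks roughly like the sketch below: a parenthesized import with one name per line and a trailing comma. Standard-library names stand in for the project's own imports here, so the snippet runs on its own; the names are placeholders, not the ones from this PR.

```python
# One imported name per line inside parentheses, with a trailing comma,
# so adding or removing a name later touches exactly one diff line.
# The stdlib names below are placeholders for the project's real imports.
from collections import (
    ChainMap,
    Counter,
    OrderedDict,
    defaultdict,
)

# Quick use of each name to show the import resolved as expected.
word_counts = Counter(["arima", "deepar", "arima"])
grouped = defaultdict(list)
grouped["models"].append("prophet")
ordered = OrderedDict(a=1, b=2)
merged = ChainMap({"a": 1}, {"b": 2})
```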
pred_df = loaded_model.predict(sample_input)

assert pred_df.columns.tolist() == [time_col, "yhat", id_col]
assert len(pred_df) == self.prediction_length * 2
assert pred_df[time_col].min() > sample_input[time_col].max()

def _check_requirements(self, run_id: str):
nit: if a method is being tested in a TestClass, make it a public method (i.e. without prefix _)
It's not a unit test itself; it's a helper method called only from other unit test cases. Should I still remove the prefix?
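A minimal sketch of what the core check inside such a helper might look like, assuming the logged requirements.txt has already been read into a string. The function name `find_pinned_requirement`, the package name `databricks-automl-runtime`, and all version numbers are illustrative assumptions; the PR's actual `_check_requirements` body is not shown in this thread.

```python
from typing import Iterable, Optional


def find_pinned_requirement(lines: Iterable[str], package: str) -> Optional[str]:
    """Return the version pinned with '==' for `package`, or None if absent."""
    wanted = package.lower().replace("_", "-")
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop trailing comments and whitespace
        if not line or "==" not in line:
            continue  # skip blank lines and requirements that are not exact pins
        name, version = line.split("==", 1)
        # Treat '-' and '_' as equivalent (a simplification of pip's name normalization).
        if name.strip().lower().replace("_", "-") == wanted:
            return version.strip()
    return None


# Hypothetical requirements.txt content; every name and version here is a placeholder.
sample_requirements = """mlflow==0.0.0
pmdarima==0.0.0
databricks-automl-runtime==0.2.20.3
"""

pinned = find_pinned_requirement(
    sample_requirements.splitlines(), "databricks_automl_runtime"
)
assert pinned is not None and pinned.startswith("0.2.20.")
```

Keeping the parsing in a plain function like this means the same check can be reused across the ARIMA, DeepAR, and Prophet test cases, with each test responsible only for fetching its own requirements file.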
Thanks @sabhya-db. A few nits; otherwise, LGTM. Thanks!
Thanks for checking, will fix the nits soon.
LGTM. Do you know if it's a serverless-only issue, or should we also port this to the main branch?
We should do this on the main branch too; an ARIMA model logged from the main branch currently can't be served. For a quick hotfix, you can just add "automl_runtime=xyz" to ARIMA_CONDA_ENV here: sabhya-db/ML-45532-sample-weight-2
Can you also port the changes for ARIMA to the main branch then?
Will do in a separate PR.
The databricks_automl dependency was already passed in to the conda env for Prophet. This adds it for ARIMA and DeepAR too, so those models log correctly.
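As a rough sketch of the idea, the snippet below builds a conda environment dict of the general shape MLflow accepts through its `conda_env` argument, with the runtime package pinned in the pip section. The helper `build_conda_env`, the package name `databricks-automl-runtime`, the Python version, and the version number are assumptions for illustration; the repository's actual `ARIMA_CONDA_ENV` definition may be constructed differently.

```python
from typing import Any, Dict, List


def build_conda_env(pip_deps: List[str], python_version: str = "3.10") -> Dict[str, Any]:
    """Build a conda environment dict with the given pip dependencies."""
    return {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            f"python={python_version}",
            "pip",
            {"pip": list(pip_deps)},  # pip packages go in a nested mapping
        ],
    }


# Placeholder version; in practice this would come from the installed runtime package.
AUTOML_RUNTIME_VERSION = "0.2.20.3"

# Hypothetical counterpart of ARIMA_CONDA_ENV with the runtime package pinned,
# so that a served model can import the wrapper class it was logged with.
arima_conda_env = build_conda_env(
    [
        "pmdarima",
        f"databricks-automl-runtime=={AUTOML_RUNTIME_VERSION}",
    ]
)
```

Pinning the exact version in the logged environment is what lets a test later assert on the contents of the model's requirements file.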